//------------------------------------------------------------------------------
//  shadercodegenerator.cc
//  (C) 2007 Radon Labs GmbH
//------------------------------------------------------------------------------
#include "stdneb.h"
#include "shadercodegenerator.h"
#include "io/ioserver.h"
#include "shadernode.h"
#include "shaderslot.h"
#include "shadersamplermanager.h"
#include "shaderfragmentmanager.h"

namespace Tools
{
using namespace Util;
using namespace IO;

//------------------------------------------------------------------------------
/**
    Constructs a code generator with the shader structure dump disabled.
    The shader to generate code for is assigned by GenerateSourceCode().
*/
ShaderCodeGenerator::ShaderCodeGenerator()
    : dumpShaderStructure(false)
{
}

//------------------------------------------------------------------------------
/**
    Generates the shader source code for shd into
    "<projDirectory>/intermediate/shaders". The output file is named
    "<shaderName>_<active group fragments joined by '_'>.fx".

    Returns false if the output stream cannot be opened, or if
    WriteSamplerDeclarations() fails (undeclared sampler). Once the text
    writer has been opened it is always closed before returning.

    NOTE(review): on a sampler failure the partially written .fx file is left
    on disk; callers should treat it as invalid.
*/
bool
ShaderCodeGenerator::GenerateSourceCode(const URI& projDirectory, const Ptr<Shader>& shd)
{
    n_assert(projDirectory.IsValid());
    n_assert(shd.isvalid());
    ShaderFragmentManager* fragManager = ShaderFragmentManager::Instance();

    this->shader = shd;

    // make sure the target directory exists
    URI shaderPath = projDirectory;
    shaderPath.AppendLocalPath("intermediate/shaders");
    IoServer::Instance()->CreateDirectory(shaderPath);

    // build the target file name from the shader name and the active fragment groups
    String outFileName = this->shader->GetName();
    outFileName.Append("_");
    outFileName.Append(String::Concatenate(fragManager->GetActiveGroupFragments(), "_"));
    outFileName.Append(".fx");

    // create the target stream and attach a text writer
    shaderPath.AppendLocalPath(outFileName);
    Ptr<Stream> stream = IoServer::Instance()->CreateStream(shaderPath);
    Ptr<TextWriter> textWriter = TextWriter::Create();
    textWriter->SetStream(stream);
    if (!textWriter->Open())
    {
        return false;
    }

    this->WriteFileHeader(textWriter);
    if (this->dumpShaderStructure)
    {
        this->WriteShaderStructureDump(textWriter);
    }
    this->WriteInputOutputDeclarations(textWriter);
    this->WriteConstantDeclarations(textWriter);

    // a sampler error aborts generation, but the writer (and its stream)
    // must still be closed; previously this path returned with it open
    bool success = this->WriteSamplerDeclarations(textWriter);
    if (success)
    {
        this->WriteFragmentFunctions(textWriter);
        this->WriteVertexShader(textWriter);
        this->WritePixelShader(textWriter);
    }
    textWriter->Close();
    return success;
}

//------------------------------------------------------------------------------
/**
*/
void
ShaderCodeGenerator::WriteFileHeader(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());
    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteFormatted("// %s.fx\n", this->shader->GetName().AsCharPtr());
    textWriter->WriteFormatted("// Generated by nsc3.exe.\n");
    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("\n");
}

//------------------------------------------------------------------------------
/**
    Writes the shader's node structure, as produced by
    Shader::DebugDumpShaderNodes(), inside a block comment. The dump is
    therefore part of the generated file but not of the compiled code.
*/
void
ShaderCodeGenerator::WriteShaderStructureDump(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    // open the block comment and its title
    textWriter->WriteString("//------------------------------------------------------------------------------\n"
                            "/**\n"
                            "    Shader Structure Dump:\n\n");

    // dump body
    this->shader->DebugDumpShaderNodes(textWriter);

    // close the block comment
    textWriter->WriteString("*/\n\n");
}

//------------------------------------------------------------------------------
/**
*/
void
ShaderCodeGenerator::WriteInputOutputDeclarations(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("/**\n");
    textWriter->WriteString("    Input/output declarations.\n");
    textWriter->WriteString("*/\n");
    
    // write vertex shader input declaration
    textWriter->WriteString("struct vsInput\n");
    textWriter->WriteString("{\n");
    const Ptr<ShaderNode>& vertexNode = this->shader->GetShaderNodes()["Vertex"];
    IndexT slotIndex;
    for (slotIndex = 0; slotIndex < vertexNode->GetOutputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = vertexNode->GetOutputSlots().ValueAtIndex(slotIndex);
        const String& dataType = slot->GetDataType();
        const String& name = slot->GetName();
        const String& semantics = slot->GetSemantics();
        textWriter->WriteFormatted("    %s %s : %s;\n", dataType.AsCharPtr(), name.AsCharPtr(), semantics.AsCharPtr());
    }
    textWriter->WriteString("};\n\n");

    // write interpolator declaration
    textWriter->WriteString("struct vsOutput\n");
    textWriter->WriteString("{\n");
    const Ptr<ShaderNode>& interpNode = this->shader->GetShaderNodes()["Interpolator"];
    for (slotIndex = 0; slotIndex < interpNode->GetInputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = interpNode->GetInputSlots().ValueAtIndex(slotIndex);
        const String& dataType = slot->GetDataType();
        const String& name = slot->GetName();
        const String& semantics = slot->GetSemantics();
        textWriter->WriteFormatted("    %s %s : %s;\n", dataType.AsCharPtr(), name.AsCharPtr(), semantics.AsCharPtr());
    }
    textWriter->WriteString("};\n\n");

    // writer pixel shader output declaration
    textWriter->WriteString("struct psOutput\n");
    textWriter->WriteString("{\n");
    const Ptr<ShaderNode>& resultNode = this->shader->GetShaderNodes()["Result"];
    for (slotIndex = 0; slotIndex < resultNode->GetInputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = resultNode->GetInputSlots().ValueAtIndex(slotIndex);
        const String& dataType = slot->GetDataType();
        const String& name = slot->GetName();
        const String& semantics = slot->GetSemantics();
        textWriter->WriteFormatted("    %s %s : %s;\n", dataType.AsCharPtr(), name.AsCharPtr(), semantics.AsCharPtr());
    }
    textWriter->WriteString("};\n\n");
}   

//------------------------------------------------------------------------------
/**
*/
void
ShaderCodeGenerator::WriteConstantDeclarations(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("/**\n");
    textWriter->WriteString("    Shader constants.\n");
    textWriter->WriteString("*/\n");

    // write shader constant declarations
    const Ptr<ShaderNode>& constNodes = this->shader->GetShaderNodes()["Constant"];
    IndexT slotIndex;
    for (slotIndex = 0; slotIndex < constNodes->GetOutputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = constNodes->GetOutputSlots().ValueAtIndex(slotIndex);
        textWriter->WriteFormatted("%s %s;\n", slot->GetDataType().AsCharPtr(), slot->GetName().AsCharPtr());
    }

    // write shared constant declarations
    const Ptr<ShaderNode>& sharedNodes = this->shader->GetShaderNodes()["Shared"];
    for (slotIndex = 0; slotIndex < sharedNodes->GetOutputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = sharedNodes->GetOutputSlots().ValueAtIndex(slotIndex);
        textWriter->WriteFormatted("shared %s %s;\n", slot->GetDataType().AsCharPtr(), slot->GetName().AsCharPtr());
    }
    textWriter->WriteString("\n\n");
}

//------------------------------------------------------------------------------
/**
*/
bool
ShaderCodeGenerator::WriteSamplerDeclarations(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("/**\n");
    textWriter->WriteString("    Texture samplers.\n");
    textWriter->WriteString("*/\n");

    // write shader constant declarations
    ShaderSamplerManager* samplerManager = ShaderSamplerManager::Instance();
    const Ptr<ShaderNode>& samplerNodes = this->shader->GetShaderNodes()["Sampler"];
    IndexT slotIndex;
    for (slotIndex = 0; slotIndex < samplerNodes->GetOutputSlots().Size(); slotIndex++)
    {
        const Ptr<ShaderSlot>& slot = samplerNodes->GetOutputSlots().ValueAtIndex(slotIndex);
        const String& samplerName = slot->GetName();
        
        // make sure the sampler exists
        if (!samplerManager->HasSampler(samplerName))
        {
            n_printf("Invalid sampler name '%s' (sampler not declared)!\n", samplerName.AsCharPtr());
            return false;
        }
        const Ptr<ShaderSampler>& sampler = samplerManager->GetSampler(samplerName);

        // write texture name
        textWriter->WriteFormatted("texture %s;\n", sampler->GetTextureParamName().AsCharPtr());

        // write sampler definition
        textWriter->WriteFormatted("sampler %s = sampler_state\n", samplerName.AsCharPtr());
        textWriter->WriteString("{\n");
        textWriter->WriteFormatted("    Texture = <%s>\n", sampler->GetTextureParamName().AsCharPtr());
        if (sampler->GetAddrU().IsValid())
        {
            textWriter->WriteFormatted("    AddressU = %s\n", sampler->GetAddrU().AsCharPtr());
        }
        if (sampler->GetAddrV().IsValid())
        {
            textWriter->WriteFormatted("    AddressV = %s\n", sampler->GetAddrV().AsCharPtr());
        }
        if (sampler->GetAddrW().IsValid())
        {
            textWriter->WriteFormatted("    AddressW = %s\n", sampler->GetAddrW().AsCharPtr());
        }
        if (sampler->GetBorderColor().IsValid())
        {
            textWriter->WriteFormatted("    BorderColor = %s\n", sampler->GetBorderColor().AsCharPtr());
        }
        if (sampler->GetMagFilter().IsValid())
        {
            textWriter->WriteFormatted("    MagFilter = %s\n", sampler->GetMagFilter().AsCharPtr());
        }
        if (sampler->GetMinFilter().IsValid())
        {
            textWriter->WriteFormatted("    MinFilter = %s\n", sampler->GetMinFilter().AsCharPtr());
        }
        if (sampler->GetMipFilter().IsValid())
        {
            textWriter->WriteFormatted("    MipFilter = %s\n", sampler->GetMipFilter().AsCharPtr());
        }
        if (sampler->GetMaxAnisotropy().IsValid())
        {
            textWriter->WriteFormatted("    MaxAnisotropy = %s\n", sampler->GetMaxAnisotropy().AsCharPtr());
        }
        if (sampler->GetMaxMipLevel().IsValid())
        {
            textWriter->WriteFormatted("    MaxMipLevel = %s\n", sampler->GetMaxMipLevel().AsCharPtr());
        }
        if (sampler->GetMipLodBias().IsValid())
        {
            textWriter->WriteFormatted("    MipMapLodBoas = %s\n", sampler->GetMipLodBias().AsCharPtr());
        }
        if (sampler->GetSRGBTexture().IsValid())
        {
            textWriter->WriteFormatted("    SRGBTexture = %s\n", sampler->GetSRGBTexture().AsCharPtr());
        }
        textWriter->WriteString("}\n\n");
    }
    return true;
}

//------------------------------------------------------------------------------
/**
    Write source code for all shader fragment functions.
*/
void
ShaderCodeGenerator::WriteFragmentFunctions(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    Array<Ptr<ShaderFragment>> fragments = this->shader->GatherShaderFragments();
    IndexT i;
    for (i = 0; i < fragments.Size(); i++)
    {
        if (fragments[i]->HasVertexShader())
        {
            textWriter->WriteString("//------------------------------------------------------------------------------\n");
            textWriter->WriteString("/**\n");
            textWriter->WriteFormatted("    '%s' Vertex Shader Function.\n", fragments[i]->GetName().AsCharPtr());
            textWriter->WriteString("*/\n");
            const Dictionary<String, ShaderParam>& inputs = fragments[i]->GetVertexShaderInputs();
            const Dictionary<String, ShaderParam>& outputs = fragments[i]->GetVertexShaderOutputs();
            const String& code = fragments[i]->GetVertexShaderCode();
            String funcName = fragments[i]->GetVertexShaderFunctionName();
            this->WriteFragmentFunction(textWriter, funcName, inputs, outputs, code);
        }
        if (fragments[i]->HasPixelShader())
        {
            textWriter->WriteString("//------------------------------------------------------------------------------\n");
            textWriter->WriteString("/**\n");
            textWriter->WriteFormatted("    '%s' Pixel Shader Function.\n", fragments[i]->GetName().AsCharPtr());
            textWriter->WriteString("*/\n");
            const Dictionary<String, ShaderParam>& inputs = fragments[i]->GetPixelShaderInputs();
            const Dictionary<String, ShaderParam>& outputs = fragments[i]->GetPixelShaderOutputs();
            const String& code = fragments[i]->GetPixelShaderCode();
            String funcName = fragments[i]->GetPixelShaderFunctionName();
            this->WriteFragmentFunction(textWriter, funcName, inputs, outputs, code);
        }
    }
}

//------------------------------------------------------------------------------
/**
    Write function source code for a single shader fragment.
*/
void
ShaderCodeGenerator::WriteFragmentFunction(const Ptr<TextWriter>& textWriter,
                                           const String& funcName,
                                           const Dictionary<String, ShaderParam>& inputs,
                                           const Dictionary<String, ShaderParam>& outputs,
                                           const String& code)
{
    textWriter->WriteFormatted("void %s(\n", funcName.AsCharPtr());

    // write input parameters
    if (inputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < inputs.Size(); i++)
        {
            const ShaderParam& param = inputs.ValueAtIndex(i);
            String line;
            line.Format("    in %s %s", param.GetType().AsCharPtr(), param.GetName().AsCharPtr());
            if (i < (inputs.Size() - 1))
            {
                line.Append(",\n");
            }
            textWriter->WriteString(line);
        }
        if (outputs.Size() > 0)
        {
            textWriter->WriteString(",\n");
        }
    }

    // write output parameters
    if (outputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < outputs.Size(); i++)
        {
            const ShaderParam& param = outputs.ValueAtIndex(i);
            String line;
            line.Format("    out %s %s", param.GetType().AsCharPtr(), param.GetName().AsCharPtr());
            if (i < (outputs.Size() - 1))
            {
                line.Append(",\n");
            }
            textWriter->WriteString(line);
        }
    }

    // write function body
    textWriter->WriteString(")\n{\n");
    Array<String> codeLines = code.Tokenize("\n");
    IndexT i;
    for (i = 0; i < codeLines.Size(); i++)
    {
        String line = codeLines[i];
        line.TrimLeft(" \t");
        if (line.IsValid())
        {
            textWriter->WriteString("    ");
            textWriter->WriteLine(line);
        }
    }
    textWriter->WriteString("}\n\n");
}

//------------------------------------------------------------------------------
/**
    This writes the input variables for a fragment call (vertex or pixel
    shader).
*/
void
ShaderCodeGenerator::WriteInputVariables(const Ptr<TextWriter>& textWriter, 
                                         const Ptr<ShaderNode>& shaderNode, 
                                         ShaderSlot::SlotType slotType)
{
    // for each input slot...
    const Dictionary<String, Ptr<ShaderSlot>>& slots = shaderNode->GetInputSlots();
    IndexT i;
    for (i = 0; i < slots.Size(); i++)
    {
        const Ptr<ShaderSlot>& slot = slots.ValueAtIndex(i);
        if (slot->GetSlotType() == slotType)
        {
            String leftHandSide;
            String rightHandSide;
            leftHandSide.Format("%s_%s", shaderNode->GetName().AsCharPtr(), slot->GetName().AsCharPtr());
            if (slot->GetConnections().Size() > 0)
            {
                n_assert(slot->GetConnections().Size() == 1);
                const Ptr<ShaderSlot>& uplinkSlot = slot->GetConnections()[0];
                const Ptr<ShaderNode>& uplinkNode = this->shader->GetShaderNodes()[uplinkSlot->GetNodeName()];
                if (uplinkNode->GetName() == "Vertex")
                {
                    rightHandSide.Format("vsIn.%s", uplinkSlot->GetName().AsCharPtr());
                }
                else if (uplinkNode->GetName() == "Interpolator")
                {
                    rightHandSide.Format("psIn.%s", uplinkSlot->GetName().AsCharPtr());
                }
                else if ((uplinkNode->GetName() == "Sampler") ||
                         (uplinkNode->GetName() == "Constant") ||
                         (uplinkNode->GetName() == "Shared"))
                {
                    rightHandSide.Format("%s", uplinkSlot->GetName().AsCharPtr());
                }
                else
                {
                    rightHandSide.Format("%s_%s", uplinkNode->GetName().AsCharPtr(), uplinkSlot->GetName().AsCharPtr());
                }
            }
            else
            {
                rightHandSide = ">>> SLOT NOT CONNECTED <<<";
            }
            textWriter->WriteFormatted("    %s %s = %s;\n", slot->GetDataType().AsCharPtr(), leftHandSide.AsCharPtr(), rightHandSide.AsCharPtr());
        }
    }
}

//------------------------------------------------------------------------------
/**
    Write shader output variables.
*/
void
ShaderCodeGenerator::WriteOutputVariables(const Ptr<TextWriter>& textWriter, 
                                          const Ptr<ShaderNode>& shaderNode,
                                          ShaderSlot::SlotType slotType)
{
    // for each output slot...
    const Dictionary<String, Ptr<ShaderSlot>>& slots = shaderNode->GetOutputSlots();
    IndexT i;
    for (i = 0; i < slots.Size(); i++)
    {
        const Ptr<ShaderSlot>& slot = slots.ValueAtIndex(i);
        if (slot->GetSlotType() == slotType)
        {
            textWriter->WriteFormatted("    %s %s_%s;\n", 
                slot->GetDataType().AsCharPtr(), 
                slot->GetNodeName().AsCharPtr(),
                slot->GetName().AsCharPtr());
        }
    }
}

//------------------------------------------------------------------------------
/**
    Write a vertex shader fragment function call.
*/
void
ShaderCodeGenerator::WriteVertexShaderFragmentCall(const Ptr<TextWriter>& textWriter, const Ptr<ShaderNode>& shaderNode)
{
    // get the node's fragment
    const Ptr<ShaderFragment>& frag = ShaderFragmentManager::Instance()->GetFragment(shaderNode->GetFragmentName());
    const Dictionary<String, ShaderParam>& inputs = frag->GetVertexShaderInputs();
    const Dictionary<String, ShaderParam>& outputs = frag->GetVertexShaderOutputs();
    String funcName = frag->GetVertexShaderFunctionName();

    textWriter->WriteFormatted("    %s(", funcName.AsCharPtr());
    if (inputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < inputs.Size(); i++)
        {
            textWriter->WriteFormatted("%s_%s", shaderNode->GetName().AsCharPtr(), inputs.ValueAtIndex(i).GetName().AsCharPtr());
            if ((i < (inputs.Size() - 1)) || (outputs.Size() > 0))
            {
                textWriter->WriteString(", ");
            }
        }
    }
    if (outputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < outputs.Size(); i++)
        {
            textWriter->WriteFormatted("%s_%s", shaderNode->GetName().AsCharPtr(), outputs.ValueAtIndex(i).GetName().AsCharPtr());
            if (i < (outputs.Size() - 1))
            {
                textWriter->WriteString(", ");
            }
        }
    }
    textWriter->WriteString(");\n");
}

//------------------------------------------------------------------------------
/**
    Write a pixel shader fragment function call.
*/
void
ShaderCodeGenerator::WritePixelShaderFragmentCall(const Ptr<TextWriter>& textWriter, const Ptr<ShaderNode>& shaderNode)
{
    // get the node's fragment
    const Ptr<ShaderFragment>& frag = ShaderFragmentManager::Instance()->GetFragment(shaderNode->GetFragmentName());
    const Dictionary<String, ShaderParam>& inputs = frag->GetPixelShaderInputs();
    const Dictionary<String, ShaderParam>& outputs = frag->GetPixelShaderOutputs();
    String funcName = frag->GetPixelShaderFunctionName();

    textWriter->WriteFormatted("    %s(", funcName.AsCharPtr());
    if (inputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < inputs.Size(); i++)
        {
            textWriter->WriteFormatted("%s_%s", shaderNode->GetName().AsCharPtr(), inputs.ValueAtIndex(i).GetName().AsCharPtr());
            if ((i < (inputs.Size() - 1)) || (outputs.Size() > 0))
            {
                textWriter->WriteString(", ");
            }
        }
    }
    if (outputs.Size() > 0)
    {
        IndexT i;
        for (i = 0; i < outputs.Size(); i++)
        {
            textWriter->WriteFormatted("%s_%s", shaderNode->GetName().AsCharPtr(), outputs.ValueAtIndex(i).GetName().AsCharPtr());
            if (i < (outputs.Size() - 1))
            {
                textWriter->WriteString(", ");
            }
        }
    }
    textWriter->WriteString(");\n");
}

//------------------------------------------------------------------------------
/**
    Write the output block for a vertex or pixel shader.
*/
void
ShaderCodeGenerator::WriteShaderReturnValues(const Ptr<TextWriter>& textWriter, 
                                             const Ptr<ShaderNode>& shaderNode, 
                                             const String& outStruct,
                                             const String& outName)
{
    // write output structure definition
    textWriter->WriteString("    //--- function output ---\n");
    textWriter->WriteFormatted("    %s %s;\n", outStruct.AsCharPtr(), outName.AsCharPtr());

    // for each input slot...
    const Dictionary<String, Ptr<ShaderSlot>>& slots = shaderNode->GetInputSlots();
    IndexT i;
    for (i = 0; i < slots.Size(); i++)
    {
        const Ptr<ShaderSlot>& slot = slots.ValueAtIndex(i);
        String leftHandSide;
        String rightHandSide;
        leftHandSide.Format("%s.%s", outName.AsCharPtr(), slot->GetName().AsCharPtr());
        if (slot->GetConnections().Size() > 0)
        {
            n_assert(slot->GetConnections().Size() == 1);
            const Ptr<ShaderSlot>& uplinkSlot = slot->GetConnections()[0];
            const Ptr<ShaderNode>& uplinkNode = this->shader->GetShaderNodes()[uplinkSlot->GetNodeName()];
            if (uplinkNode->GetName() == "Vertex")
            {
                rightHandSide.Format("vsIn.%s", uplinkSlot->GetName().AsCharPtr());
            }
            else if (uplinkNode->GetName() == "Interpolator")
            {
                rightHandSide.Format("psIn.%s", uplinkSlot->GetName().AsCharPtr());
            }
            else if ((uplinkNode->GetName() == "Sampler") ||
                     (uplinkNode->GetName() == "Constant") ||
                     (uplinkNode->GetName() == "Shared"))
            {
                rightHandSide.Format("%s", uplinkSlot->GetName().AsCharPtr());
            }
            else
            {
                rightHandSide.Format("%s_%s", uplinkNode->GetName().AsCharPtr(), uplinkSlot->GetName().AsCharPtr());
            }
        }
        else
        {
            rightHandSide = ">>> SLOT NOT CONNECTED <<<";
        }
        textWriter->WriteFormatted("    %s = %s;\n", leftHandSide.AsCharPtr(), rightHandSide.AsCharPtr());
    }

    // write return statement
    textWriter->WriteFormatted("    return %s;\n", outName.AsCharPtr());
}

//------------------------------------------------------------------------------
/**
    Write the vertex shader functions.
*/
void
ShaderCodeGenerator::WriteVertexShader(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("/**\n");
    textWriter->WriteString("    Vertex Shader\n");
    textWriter->WriteString("*/\n");

    // write function header
    textWriter->WriteString("vsOutput vsMain(vsInput vsIn)\n");
    textWriter->WriteString("{\n");

    // get shader node in reverse dependency order
    Array<Ptr<ShaderNode>> shaderNodes;
    this->shader->GetUplinkDependencyNodes("Interpolator", ShaderSlot::VertexShader, shaderNodes);
    IndexT i;
    for (i = 0; i < shaderNodes.Size(); i++)
    {
        textWriter->WriteFormatted("    //--- %s ---\n", shaderNodes[i]->GetName().AsCharPtr());
        this->WriteInputVariables(textWriter, shaderNodes[i], ShaderSlot::VertexShader);
        this->WriteOutputVariables(textWriter, shaderNodes[i], ShaderSlot::VertexShader);
        this->WriteVertexShaderFragmentCall(textWriter, shaderNodes[i]);
        textWriter->WriteFormatted("\n");
    }

    // write the function return values
    this->WriteShaderReturnValues(textWriter, this->shader->GetShaderNodes()["Interpolator"], "vsOutput", "vsOut");
    textWriter->WriteString("}\n");
}

//------------------------------------------------------------------------------
/**
    Write the vertex shader functions.
*/
void
ShaderCodeGenerator::WritePixelShader(const Ptr<TextWriter>& textWriter)
{
    n_assert(this->shader.isvalid());

    textWriter->WriteString("//------------------------------------------------------------------------------\n");
    textWriter->WriteString("/**\n");
    textWriter->WriteString("    Pixel Shader\n");
    textWriter->WriteString("*/\n");

    // write function header
    textWriter->WriteString("psOutput psMain(vsOutput psIn)\n");
    textWriter->WriteString("{\n");

    // get shader node in reverse dependency order
    Array<Ptr<ShaderNode>> shaderNodes;
    this->shader->GetUplinkDependencyNodes("Result", ShaderSlot::PixelShader, shaderNodes);
    IndexT i;
    for (i = 0; i < shaderNodes.Size(); i++)
    {
        textWriter->WriteFormatted("    //--- %s ---\n", shaderNodes[i]->GetName().AsCharPtr());
        this->WriteInputVariables(textWriter, shaderNodes[i], ShaderSlot::PixelShader);
        this->WriteOutputVariables(textWriter, shaderNodes[i], ShaderSlot::PixelShader);
        this->WritePixelShaderFragmentCall(textWriter, shaderNodes[i]);
        textWriter->WriteFormatted("\n");
    }

    // write the function return values
    this->WriteShaderReturnValues(textWriter, this->shader->GetShaderNodes()["Result"], "psOutput", "psOut");
    textWriter->WriteString("}\n");
}

} // namespace Tools